{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy loan data set: 15 samples (rows) described by 4 integer-coded features (columns).\n",
    "# Feature 0: 0 = young, 1 = middle-aged, 2 = elderly\n",
    "# Feature 1: 0 = no job, 1 = has a job\n",
    "# Feature 2: 0 = owns no house, 1 = owns a house\n",
    "# Feature 3: 0 = fair credit, 1 = good credit, 2 = very good credit\n",
    "X = np.array([[0,0,0,0],\n",
    "              [0,0,0,1],\n",
    "              [0,1,0,1],\n",
    "              [0,1,1,0],\n",
    "              [0,0,0,0],\n",
    "              [1,0,0,0],\n",
    "              [1,0,0,1],\n",
    "              [1,1,1,1],\n",
    "              [1,0,1,2],\n",
    "              [1,0,1,2],\n",
    "              [2,0,1,2],\n",
    "              [2,0,1,1],\n",
    "              [2,1,0,1],\n",
    "              [2,1,0,2],\n",
    "              [2,0,0,0]])\n",
    "# Label, one per row of X: 0 = no loan, 1 = loan\n",
    "y = np.array([0,0,1,1,0,0,0,1,1,1,1,1,1,1,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi(y):\n",
    "    \"\"\"Return the most frequent label in ``y`` (the majority class).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    y : array-like of labels (numpy array or plain sequence).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The label with the highest count, as a numpy scalar. When several labels\n",
    "    share the highest count, the smallest of them is returned.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``y`` is empty, because an empty collection has no majority label.\n",
    "    \"\"\"\n",
    "    # np.unique returns the labels sorted, so the tie-break below is deterministic.\n",
    "    labels, counts = np.unique(np.asarray(y), return_counts=True)\n",
    "    if labels.size == 0:\n",
    "        raise ValueError('multi() needs at least one label')\n",
    "    # argmax returns the first maximum, i.e. the smallest label among ties.\n",
    "    return labels[np.argmax(counts)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gini(y):\n",
    "    \"\"\"Gini impurity of the label array ``y``: 1 minus the sum of squared class shares.\n",
    "\n",
    "    ``y`` must be a numpy array (its ``shape`` and boolean-mask indexing are used).\n",
    "    An empty array yields 1 because there is no class share to subtract.\n",
    "    \"\"\"\n",
    "    total = y.shape[0]\n",
    "    impurity = 1\n",
    "    for cls in set(y):\n",
    "        # Share of the samples that carry this class label.\n",
    "        share = y[y == cls].shape[0] / total\n",
    "        impurity -= share ** 2\n",
    "    return impurity\n",
    "\n",
    "def CART(X, y):\n",
    "    \"\"\"Grow a binary CART classification tree by greedy Gini-minimising splits.\n",
    "\n",
    "    Each candidate split tests ``X[:, feature] <= value`` for a value that the\n",
    "    feature actually takes in ``X``; rows passing the test go left, the rest right.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : numpy.ndarray, 2-D, one row per sample and one column per feature.\n",
    "    y : numpy.ndarray, 1-D, class labels aligned with the rows of ``X``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Leaves hold 'label', 'gini', 'Tt' (always 1), 'y' and 'cost'. Internal\n",
    "        nodes also hold 'feature', 'value', 'left' and 'right'; their 'Tt' is the\n",
    "        number of leaves below them and their 'cost' is the sample-weighted\n",
    "        impurity of those leaves.\n",
    "    \"\"\"\n",
    "    # All samples belong to a single class C_k: emit a pure leaf labelled C_k.\n",
    "    if len(set(y)) == 1:\n",
    "        return {'label': y[0], 'gini': 0, 'Tt': 1, 'y': y.copy(), 'cost': 0}\n",
    "\n",
    "    n = y.shape[0]\n",
    "    bestFeature, bestValue, bestGini = None, None, np.inf\n",
    "    # Try every value of every feature as a threshold.\n",
    "    for feature in range(X.shape[1]):\n",
    "        column = X[:, feature]\n",
    "        for value in set(column):\n",
    "            # Partition the labels into R_1 (<= value) and R_2 (> value).\n",
    "            y1 = y[column <= value]\n",
    "            y2 = y[column > value]\n",
    "            # A threshold that sends every row left separates nothing; choosing\n",
    "            # it would recurse on the same data forever, so it is not a candidate.\n",
    "            if y2.shape[0] == 0:\n",
    "                continue\n",
    "            # Size-weighted sum of the Gini indices of R_1 and R_2.\n",
    "            sumGini = y1.shape[0]/n*gini(y1) + y2.shape[0]/n*gini(y2)\n",
    "            # Keep the (feature, value) pair with the smallest weighted Gini.\n",
    "            if sumGini < bestGini:\n",
    "                bestFeature, bestValue, bestGini = feature, value, sumGini\n",
    "\n",
    "    if bestFeature is None:\n",
    "        # No threshold separates the rows (e.g. identical rows with different\n",
    "        # labels): stop here with an impure leaf carrying the majority label.\n",
    "        impurity = gini(y)\n",
    "        return {'label': multi(y), 'gini': impurity, 'Tt': 1, 'y': y.copy(), 'cost': impurity}\n",
    "\n",
    "    # Build the two children from the best split and recurse into each.\n",
    "    goLeft = X[:, bestFeature] <= bestValue\n",
    "    goRight = X[:, bestFeature] > bestValue\n",
    "    left = CART(X[goLeft], y[goLeft])\n",
    "    right = CART(X[goRight], y[goRight])\n",
    "    node = {'feature': bestFeature,\n",
    "            'value': bestValue,\n",
    "            'gini': gini(y),\n",
    "            'y': y.copy(),\n",
    "            'label': multi(y),\n",
    "            'left': left,\n",
    "            'right': right}\n",
    "    node['Tt'] = left['Tt'] + right['Tt']\n",
    "    # Both children must contribute to the subtree cost (one parenthesised\n",
    "    # expression), each weighted by its share of this node's samples so that\n",
    "    # the cost stays comparable with this node's own 'gini'.\n",
    "    node['cost'] = (left['y'].shape[0]*left['cost']\n",
    "                    + right['y'].shape[0]*right['cost'])/n\n",
    "    return node\n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'feature': 2,\n",
       " 'value': 0,\n",
       " 'gini': 0.48,\n",
       " 'y': array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]),\n",
       " 'label': 1,\n",
       " 'left': {'feature': 1,\n",
       "  'value': 0,\n",
       "  'gini': 0.4444444444444445,\n",
       "  'y': array([0, 0, 1, 0, 0, 0, 1, 1, 0]),\n",
       "  'label': 0,\n",
       "  'left': {'label': 0,\n",
       "   'gini': 0,\n",
       "   'Tt': 1,\n",
       "   'y': array([0, 0, 0, 0, 0, 0]),\n",
       "   'cost': 0},\n",
       "  'right': {'label': 1, 'gini': 0, 'Tt': 1, 'y': array([1, 1, 1]), 'cost': 0},\n",
       "  'Tt': 2,\n",
       "  'cost': 0.0},\n",
       " 'right': {'label': 1,\n",
       "  'gini': 0,\n",
       "  'Tt': 1,\n",
       "  'y': array([1, 1, 1, 1, 1, 1]),\n",
       "  'cost': 0},\n",
       " 'Tt': 3,\n",
       " 'cost': 0.0}"
      ]
     },
     "execution_count": 143,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Grow a tree on the current X, y; the bare name on the last line makes the notebook display it.\n",
    "tree = CART(X, y)\n",
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [],
   "source": [
    "def isLeaf(Node):\n",
    "    \"\"\"Tell whether ``Node`` is a leaf, i.e. whether its stored leaf count 'Tt' equals 1.\"\"\"\n",
    "    leafCount = Node['Tt']\n",
    "    return leafCount == 1\n",
    "    \n",
    "def calcAlphaList(Node):\n",
    "    \"\"\"Store the pruning threshold 'alpha' on every internal node at or below ``Node``.\n",
    "\n",
    "    For an internal node, alpha = (gini - cost) / (Tt - 1): the impurity added by\n",
    "    collapsing the node into a leaf, per leaf removed. Leaves are skipped.\n",
    "\n",
    "    The smallest alpha seen is kept in the function attribute\n",
    "    ``calcAlphaList.bestAlpha``. It persists between calls, so reset it to\n",
    "    ``np.inf`` before each top-level call, as prune() does.\n",
    "\n",
    "    Returns None; the results are the 'alpha' keys and ``bestAlpha``.\n",
    "    \"\"\"\n",
    "    if isLeaf(Node):\n",
    "        return\n",
    "    costNotSplit = Node['gini']\n",
    "    costSplit = Node['cost']\n",
    "    alpha = (costNotSplit - costSplit)/(Node['Tt'] - 1)\n",
    "    if alpha < 0:\n",
    "        # Diagnostic dump for a negative alpha: the node, then each child, with\n",
    "        # every line pairing a node's Tt and gini with that same node's labels.\n",
    "        print (Node['Tt'], Node['gini'], Node['y'])\n",
    "        print (Node['right']['Tt'], Node['right']['gini'], Node['right']['y'])\n",
    "        print (Node['left']['Tt'], Node['left']['gini'], Node['left']['y'])\n",
    "    Node['alpha'] = alpha\n",
    "    # Fall back to +inf so a direct call without a prior reset does not fail.\n",
    "    if alpha < getattr(calcAlphaList, 'bestAlpha', np.inf):\n",
    "        calcAlphaList.bestAlpha = alpha\n",
    "    calcAlphaList(Node['left'])\n",
    "    calcAlphaList(Node['right'])\n",
    "    \n",
    "def calcTt(Node):\n",
    "    \"\"\"Recount the leaves under ``Node`` and store the count in its 'Tt' key.\n",
    "\n",
    "    A leaf counts as 1 and is left untouched. Returns the leaf count of ``Node``.\n",
    "    \"\"\"\n",
    "    if not isLeaf(Node):\n",
    "        leavesLeft = calcTt(Node['left'])\n",
    "        leavesRight = calcTt(Node['right'])\n",
    "        Node['Tt'] = leavesLeft + leavesRight\n",
    "        return Node['Tt']\n",
    "    return 1\n",
    "\n",
    "def cut(Node, alpha):\n",
    "    \"\"\"Collapse into a leaf each topmost internal node whose stored 'alpha' equals ``alpha``.\n",
    "\n",
    "    A collapsed node loses 'left', 'right', 'feature' and 'value', takes its own\n",
    "    'gini' as 'cost' and gets 'Tt' set to 1. The tree is modified in place and\n",
    "    nothing is returned. Nodes below a collapsed node are not visited.\n",
    "    \"\"\"\n",
    "    if isLeaf(Node):\n",
    "        return\n",
    "    if Node['alpha'] != alpha:\n",
    "        # Not this node: keep searching in both subtrees, left first.\n",
    "        for side in ('left', 'right'):\n",
    "            cut(Node[side], alpha)\n",
    "        return\n",
    "    for key in ('left', 'right', 'feature', 'value'):\n",
    "        Node.pop(key)\n",
    "    Node['cost'] = Node['gini']\n",
    "    Node['Tt'] = 1\n",
    "        \n",
    "def _refreshCost(Node):\n",
    "    \"\"\"Recompute 'cost' bottom-up (sample-weighted mean of the children's cost); return it.\"\"\"\n",
    "    if isLeaf(Node):\n",
    "        return Node['cost']\n",
    "    left, right = Node['left'], Node['right']\n",
    "    nLeft, nRight = len(left['y']), len(right['y'])\n",
    "    Node['cost'] = (nLeft*_refreshCost(left) + nRight*_refreshCost(right))/(nLeft + nRight)\n",
    "    return Node['cost']\n",
    "\n",
    "def prune(Tree):\n",
    "    \"\"\"Build the sequence of progressively pruned trees by weakest-link pruning.\n",
    "\n",
    "    Each round recomputes leaf counts and subtree costs, finds the smallest alpha\n",
    "    among the internal nodes, prints it, and collapses the node(s) that attain it.\n",
    "    Rounds repeat until only the root is left.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Tree : dict, a tree as built by CART(). It is pruned IN PLACE and ends up\n",
    "        as a single leaf.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of dict\n",
    "        Independent snapshots: the tree as passed in, then the tree after each\n",
    "        round. A tree that is already a leaf yields a one-element list.\n",
    "    \"\"\"\n",
    "    # Deep copies: cut() and calcAlphaList() edit nested nodes in place, which\n",
    "    # would otherwise alter the snapshots already stored in the list.\n",
    "    TreeList = [copy.deepcopy(Tree)]\n",
    "    while not isLeaf(Tree):\n",
    "        calcTt(Tree)  # a cut removes leaves, so Tt must be recounted every round\n",
    "        _refreshCost(Tree)  # a cut gives a node non-zero cost; propagate it upwards\n",
    "        calcAlphaList.bestAlpha = np.inf\n",
    "        calcAlphaList(Tree)\n",
    "        print ('alpha',calcAlphaList.bestAlpha)\n",
    "        cut(Tree, calcAlphaList.bestAlpha)\n",
    "        TreeList.append(copy.deepcopy(Tree))\n",
    "    # Return only after the loop has finished, so every pruning round is recorded.\n",
    "    return TreeList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "alpha 0.24\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'feature': 2,\n",
       "  'value': 0,\n",
       "  'gini': 0.48,\n",
       "  'y': array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]),\n",
       "  'label': 1,\n",
       "  'left': {'feature': 1,\n",
       "   'value': 0,\n",
       "   'gini': 0.4444444444444445,\n",
       "   'y': array([0, 0, 1, 0, 0, 0, 1, 1, 0]),\n",
       "   'label': 0,\n",
       "   'left': {'label': 0,\n",
       "    'gini': 0,\n",
       "    'Tt': 1,\n",
       "    'y': array([0, 0, 0, 0, 0, 0]),\n",
       "    'cost': 0},\n",
       "   'right': {'label': 1, 'gini': 0, 'Tt': 1, 'y': array([1, 1, 1]), 'cost': 0},\n",
       "   'Tt': 2,\n",
       "   'cost': 0.0,\n",
       "   'alpha': 0.4444444444444445},\n",
       "  'right': {'label': 1,\n",
       "   'gini': 0,\n",
       "   'Tt': 1,\n",
       "   'y': array([1, 1, 1, 1, 1, 1]),\n",
       "   'cost': 0},\n",
       "  'Tt': 3,\n",
       "  'cost': 0.0},\n",
       " {'gini': 0.48,\n",
       "  'y': array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]),\n",
       "  'label': 1,\n",
       "  'Tt': 1,\n",
       "  'cost': 0.48,\n",
       "  'alpha': 0.24}]"
      ]
     },
     "execution_count": 145,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regrow the tree on the current X, y, then display the list that prune() returns for it.\n",
    "tree = CART(X, y)\n",
    "prune(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD5CAYAAAA3Os7hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAXn0lEQVR4nO3dfYxcV3nH8d/j2RQwASI1qxLFL6sKhNREzotXISgVCjiteLGMKkBytZQaFbm1SQlqK9oQKRKWUFVValNI7WgJqkLjltDwUoNSWhKIGlRhtA4hEEwrg+LEhDabpCSlBirbT/+4s/Ls7MzsPTNz5p5z5vuRrnbmzsnd59x798n1uc89Y+4uAED+NjQdAABgPEjoAFAIEjoAFIKEDgCFIKEDQCFI6ABQiJm6Dc2sJWlJ0g/dfWfXZ3sk/bmkH7ZX3e7udw7a3sUXX+xzc3NBwQLAtDt27Ngz7j7b67PaCV3STZKOS3p5n8/vcfcb625sbm5OS0tLAb8eAGBmJ/t9VmvIxcw2SXqrpIFX3QCA5tQdQ79N0gclnRvQ5u1m9qiZ3Wtmm3s1MLO9ZrZkZkvLy8uhsQIABlg3oZvZTklPu/uxAc2+IGnO3bdJul/SXb0aufuiu8+7+/zsbM8hIADAkOpcoV8naZeZPS7pU5LeaGZ3dzZw92fd/efttx+XtH2sUQIA1rVuQnf3m919k7vPSdot6Svu/q7ONmZ2ScfbXapungIAJiikymUVMzsgacndj0h6v5ntknRG0nOS9ownPABAXUEPFrn7gys16O5+azuZr1zFX+buV7j7G9z9ezGCBRpx+LA0Nydt2FD9PHy46YiAnoa+QgemwuHD0t690unT1fuTJ6v3krSw0FxcQA88+g8Mcsst55P5itOnq/VAYkjowCBPPBG2HmgQCR0YZMuWsPVAg0jowCAf+Yi0cePqdRs3VuuBxJDQgUEWFqTFRWnrVsms+rm4yA1RJIkqF2A9CwskcGSBK3QAKAQJHQAKQUIHgEKQ0AGgECR0ACgECR0ACkFCB4BCkNABoBAkdAAoBAkd5eCLKDDlePQfZeCLKACu0FEIvogCIKGjEHwRBUBCRyH4IgqAhI5C8EUUAAkdheCLKACqXFAQvogCU44rdIyO+m8gCVyhYzTUfwPJ4Aodo6H+G0gGCR2jof4bSAYJHaOh/htIBgkdo6H+G0gGCR2jof4bSAZVLhgd9d9AEmpfoZtZy8y+aWZf7PHZi8zsHjM7YWZHzWxunEEC2aAmHw0KGXK5SdLxPp/9jqT/dvdXSfpLSX82amBAdlZq8k+elNzP1+ST1DEhtRK6mW2S9FZJd/Zp8jZJd7Vf3ytph5nZ6OEBGaEmHw2re4V+m6QPSjrX5/NLJT0pSe5+RtLzkn6xu5GZ7TWzJTNbWl5eHiJcIGHU5KNh6yZ0M9sp6Wl3PzaoWY91vmaF+6K7z7v7/OzsbECYQAaoyUfD6lyhXydpl5k9LulTkt5oZnd3tTklabMkmdmMpFdIem6McQLpoyYfDVs3obv7ze6+yd3nJO2W9BV3f1dXsyOSfrv9+h3tNmuu0IGiUZOPhg1dh25mByQtufsRSZ+Q9LdmdkLVlfnuMcUH5IWafDQo6ElRd3/Q3Xe2X9/aTuZy95+5+zvd/VXufo27/yBGsJgy+/dLMzPV1e7MTPUeQF88KYo07d8vHTp0/v3Zs+ffHzzYTExA4pjLBWlaXAxbD4CEjkSdPRu2HgAJHYlqtcLWAyChI1Er30tadz0AbooiUSs3PhcXq2GWVqtK5twQBfoioSNdBw+SwIEADLmgtxtuqOq/V5Ybbmg6ouYwxzkyQULHWjfcID3wwOp1DzwwnUmdOc6REWtqypX5+XlfWlpq5HdjHYOmsp+2KXrm5qok3m3rVunxxycdDSAzO+bu870+4wodGIQ5zpEREjowCHOcIyMkdKy1Y0fY+pIxxzky
QkLHWvffvzZ579hRrZ82zHGOjHBTFAAywk1RhItVex2yXeq/gSA8KYq1VmqvT5+u3q/UXkujDTWEbDdWDEDBGHLBWrFqr0O2S/030BNDLggTq/Y6ZLvUfwPBSOhYK1btdch2qf8GgpHQsVas2uuQ7VL/DQQjoWOtWLXXIdul/hsIxk1RAMgIN0VjSKFGOjSGFGIGEA116MNIoUY6NIYUYgYQFUMuw0ihRjo0hhRiBjAyhlzGLYUa6dAYUogZQFQk9GGkUCMdGkMKMQOIioQ+jBRqpENjSCFmAFGR0IeRQo10aAwpxAwgKm6KAkBGRropamYvNrNvmNm3zOwxM/twjzZ7zGzZzB5pL+8dR+AYs/37pZmZ6gp9ZqZ6P462qdS3pxIH0BR3H7hIMkkXtl9fIOmopGu72uyRdPt62+pctm/f7pigffvcpbXLvn2jtb37bveNG1e327ixWj9JqcQBRCZpyfvk1aAhFzPbKOlrkva5+9GO9Xskzbv7jXW3xZDLhM3MSGfPrl3faklnzgzfNpX69lTiACIbuQ7dzFpm9oikpyV9uTOZd3i7mT1qZvea2eY+29lrZktmtrS8vFy7AxiDXgm63/qQtqnUt6cSB9CgWgnd3c+6+5WSNkm6xswu72ryBUlz7r5N0v2S7uqznUV3n3f3+dnZ2VHiRqhWq/76kLap1LenEgfQoKCyRXf/saQHJb2pa/2z7v7z9tuPS9o+lugwPivzttRZH9I2lfr2VOIAmtRvcH1lkTQr6aL265dIekjSzq42l3S8/g1JX19vu9wUbcC+fe6tVnXDsNXqfZNzmLZ33+2+dau7WfWzqRuRqcQBRKRRboqa2TZVQygtVVf0n3b3A2Z2oL3hI2b2p5J2SToj6TlVN02/N2i73BQFgHAj3RR190fd/Sp33+bul7v7gfb6W939SPv1ze5+mbtf4e5vWC+ZFyFWzXNI/XfMbYf0L8d9kRlK7FFLv0v32EvWQy6xap5D6r9jbjukfznui8xQYo9OGlcd+jhlPeQSq+Y5pP475rZD+pfjvsgMJfboNGjIhYQ+jA0bqgulbmbSuXPDb9es/2ejHqeQbYf0L8d9kZlYuxh54gsuxi1WzXNI/XfMbYf0L8d9kRlK7FEXCX0YsWqeQ+q/Y247pH857ovMUGKP2voNrsdesr4p6h6v5jmk/jvmtkP6l+O+yAwl9lghbooCQBkYQ0clhdpyZI3TIm0zTQeACTl8uBp/Pn26en/y5Pnx6O6voQtpi6nBaZE+hlymRQq15cgap0UaGHJB2HzhzC2OHjgt0kdCnxYp1JYja5wW6SOhT4sUasuRNU6L9JHQp8XCgrS4WA14mlU/Fxd7380KaYupwWmRPm6KAkBGpvumaKzC2ZDtpjKvN0XESSn9cJTevxAT2xf9HiGNvUzk0f9YE0mHbDeVeb2ZVDsppR+O0vsXYtz7QlP76H+swtmQ7aYyrzdFxEkp/XCU3r8Q494X0zsfeqyJpEO2m8q83kyqnZTSD0fp/Qsx7n0xvWPosQpnQ7abyrzeFBEnpfTDUXr/QkxyX5Sd0GMVzoZsN5V5vSkiTkrph6P0/oWY6L7oN7gee5nYfOixJpIO2W4q83ozqXZSSj8cpfcvxDj3hab2pigAFGZ6x9Bjor4dyEKsP5Mk6+z7XbrHXrL+Cjrq24EsxPozabLOXgy5jBn17UAWYv2ZNFlnz5DLuMWaGDpku73O0kHrgSkU688k1bnhSejDoL4dyEKsP5NU6+xJ6MOgvh3IQqw/k2Tr7PsNrsdesr4p6k59O5CJWH8mTdXZi5uiAFCGkW6KmtmLzewbZvYtM3vMzD7co82LzOweMzthZkfNbG70sPsILf5Mslh0gJCi2cL3RcxwY+7mumL2L7NDHaTw0340/S7dVxZJJunC9usLJB2VdG1Xm/2S7mi/3i3pnvW2O9SQS2jxZ26TMocUzRa+L2KGG3M31xWzf5kd6iCFn/a1aMCQS9C4t6SNkh6W
9Nqu9f8s6XXt1zOSnlF7at5+y1AJfevW3n+JW7eOp33TVgb6updWa23bwvdFzHBj7ua6YvYvs0MdpPDTvpZBCb3WGLqZtSQdk/QqSX/t7n/c9fl3JL3J3U+133+/nfSf6Wq3V9JeSdqyZcv2k70q8wcJnVg4t0mZQ+ZOL3xfxAw35m6uK2b/MjvUQQo/7WsZ+cEidz/r7ldK2iTpGjO7vPt39PrPemxn0d3n3X1+dna2zq9eLbT4M9Vi0X5CimYL3xcxw425m+uK2b/MDnWQwk/7kQXVobv7jyU9KOlNXR+dkrRZksxsRtIrJD03hvhWCy3+TLZYtI+QotnC90XMcGPu5rpi9i+zQx2k8NN+dP3GYlYWSbOSLmq/fomkhyTt7GrzPq2+Kfrp9bY7dB16aPFnbpMyhxTNFr4vYoYbczfXFbN/mR3qIIWf9uvSKGPoZrZN0l2SWqqu6D/t7gfM7EB7w0fM7MWS/lbSVaquzHe7+w8GbZc6dAAIN2gMfWa9/9jdH1WVqLvX39rx+meS3jlKkACA0ZQ/l8tUPVWAukJOixROoZgP0+T24FQKxyNZ/cZiYi8TmculxKcKMLKQ0yKFUyjmwzS5PTiVwvFomqZ2LpcmZ6FHskJOixROodAYUuhfbtvNyaAx9LITeolPFWBkIadFCqdQzIdpcntwKoXj0bTp/caiaXuqALWEnBYpnEIxH6bJ7cGpFI5HyspO6FP3VAHqCDktUjiFYj5Mk9uDUykcj6T1G1yPvUzsCy5Ke6oAYxFyWqRwCsV8mCa3B6dSOB5N0tTeFAWAwkzvGDowBiFfhpGK3GJOpbY8lTiG1u/SPfaS/XeKYiqEfBlGKnKLOZXa8lTiWI8YcgGGMzMjnT27dn2rJZ05M/l46sgt5lRqy1OJYz0MuQBD6pUYB61PQW4xP/FE2PrS4xgFCR0YIOTLMFKRW8yp1JanEscoSOjAACFfhpGK3GJOpbY8lThG0m9wPfbCTVHkIuTLMFKRW8yp1JanEscg4qYoAJSBm6KIKsfa3Vgxx6r/znEfowH9Lt1jLwy5lCGX2t1OsWKOVf+d4z5GPGLIBbHkUrvbKVbMseq/c9zHiIchF0STY+1urJhj1X/nuI/RDBI6RpJj7W6smGPVf+e4j9EMEjpGkmPtbqyYY9V/57iP0ZB+g+uxF26KliOH2t1usWKOVf+d4z5GHOKmKACUgZuimAqxarVDtku9OJo003QAwDgcPlyNVZ8+Xb0/efL82PXCwmS2GysGoC6GXFCEWLXaIdulXhyTwJALiherVjtku9SLo2kkdBQhVq12yHapF0fTSOgoQqxa7ZDtUi+OppHQUYSFBWlxsRqvNqt+Li6OfjMyZLuxYgDq4qYoAGRkpJuiZrbZzL5qZsfN7DEzu6lHm+vN7Hkze6S93DqOwNGcHOupqRePj/2WuH6PkK4ski6RdHX79csk/YekX+lqc72kL663rc6FR//TleP82yEx59i/FLDf0qBxPvpvZv8o6XZ3/3LHuusl/ZG776y7HYZc0pVjPTX14vGx39IwaMglKKGb2Zykf5V0ubu/0LH+ekmfkXRK0lOqkvtjPf77vZL2StKWLVu2n+x1dqBxGzZU11/dzKRz5yYfTx0hMefYvxSw39IwlgeLzOxCVUn7A53JvO1hSVvd/QpJH5P0+V7bcPdFd5939/nZ2dm6vxoTlmM9NfXi8bHf0lcroZvZBaqS+WF3/2z35+7+grv/pP36PkkXmNnFY40UE5NjPTX14vGx3zLQb3B9ZZFkkj4p6bYBbV6p88M310h6YuV9v4WbomnLcf7tkJhz7F8K2G/N0yg3Rc3sVyU9JOnbklZGyj4kaUv7fwh3mNmNkvZJOiPpp5L+wN3/bdB2uSkKAOFGGkN396+5u7n7Nne/sr3c5+53uPsd7Ta3u/tl7n6Fu1+7XjLHeFATvNr+/dLMTHWTbmameg9ME+ZDzxRzb6+2f7906ND592fPnn9/8GAzMQGT
xqP/maImeLWZmSqJd2u1pDNnJh8PEAvzoReIubdX65XMB60HSkRCzxQ1wau1WmHrgRKR0DNFTfBqK/cP6q4HSkRCzxRzb6928KC0b9/5K/JWq3rPDVFME26KAkBGuClaV+GF3YV3r/j+pYB9nLh+j5DGXpJ79L/wyZ4L717x/UsB+zgNGud86OOS3JBL4YXdhXev+P6lgH2chrHNhz5OySX0wid7Lrx7xfcvBezjNDCGXkfhhd2Fd6/4/qWAfZw+EvqKwgu7C+9e8f1LAfs4A/0G12Mvyd0UdS9+sufCu1d8/1LAPm6euCkKAGVgDB3IXMz6b2rLy8F86EDiYs59z7z6ZWHIBUhczPpvasvzw5ALkLGYc98zr35ZSOhA4mLWf1NbXhYSOpC4mPXf1JaXhYQOJC7m3PfMq18WbooCQEa4KQoAU4CEDgCFIKEDQCFI6ABQCBI6ABSChA4AhSChA0AhSOgAUIh1E7qZbTazr5rZcTN7zMxu6tHGzOyjZnbCzB41s6vjhItRMO81ULY686GfkfSH7v6wmb1M0jEz+7K7f7ejzZslvbq9vFbSofZPJIJ5r4HyrXuF7u4/cveH26//R9JxSZd2NXubpE+2v/Lu65IuMrNLxh4thnbLLeeT+YrTp6v1AMoQNIZuZnOSrpJ0tOujSyU92fH+lNYmfZnZXjNbMrOl5eXlsEgxEua9BspXO6Gb2YWSPiPpA+7+QvfHPf6TNbN+ufuiu8+7+/zs7GxYpBgJ814D5auV0M3sAlXJ/LC7f7ZHk1OSNne83yTpqdHDw7gw7zVQvjpVLibpE5KOu/tf9Gl2RNK729Uu10p63t1/NMY4MSLmvQbKV6fK5TpJvyXp22b2SHvdhyRtkSR3v0PSfZLeIumEpNOS3jP+UDGqhQUSOFCydRO6u39NvcfIO9u4pPeNKygAQDieFAWAQpDQAaAQJHQAKAQJHQAKQUIHgEKQ0AGgECR0ACiEVSXkDfxis2VJJxv55eu7WNIzTQcREf3LV8l9k+hfHVvdvedkWI0l9JSZ2ZK7zzcdRyz0L18l902if6NiyAUACkFCB4BCkNB7W2w6gMjoX75K7ptE/0bCGDoAFIIrdAAoBAkdAAox1QndzFpm9k0z+2KPz/aY2bKZPdJe3ttEjKMws8fN7Nvt+Jd6fG5m9lEzO2Fmj5rZ1U3EOYwafbvezJ7vOH63NhHnsMzsIjO718y+Z2bHzex1XZ9ne+ykWv3L9viZ2Ws64n7EzF4wsw90tYly/Op8Y1HJbpJ0XNLL+3x+j7vfOMF4YniDu/d7kOHNkl7dXl4r6VD7Zy4G9U2SHnL3nROLZrz+StKX3P0dZvYLkrq+ETb7Y7de/6RMj5+7/7ukK6XqolHSDyV9rqtZlOM3tVfoZrZJ0lsl3dl0LA16m6RPeuXrki4ys0uaDmramdnLJb1e1Xf5yt3/z91/3NUs22NXs3+l2CHp++7e/VR8lOM3tQld0m2SPijp3IA2b2//c+heM9s8objGySX9i5kdM7O9PT6/VNKTHe9PtdflYL2+SdLrzOxbZvZPZnbZJIMb0S9LWpb0N+0hwTvN7KVdbXI+dnX6J+V7/DrtlvT3PdZHOX5TmdDNbKekp9392IBmX5A05+7bJN0v6a6JBDde17n71ar+efc+M3t91+e9vis2lzrW9fr2sKo5L66Q9DFJn590gCOYkXS1pEPufpWk/5X0J11tcj52dfqX8/GTJLWHknZJ+odeH/dYN/Lxm8qELuk6SbvM7HFJn5L0RjO7u7OBuz/r7j9vv/24pO2TDXF07v5U++fTqsbwrulqckpS5788Nkl6ajLRjWa9vrn7C+7+k/br+yRdYGYXTzzQ4ZySdMrdj7bf36sqAXa3yfLYqUb/Mj9+K94s6WF3/68en0U5flOZ0N39Znff5O5zqv5J9BV3f1dnm67xrF2qbp5mw8xeamYvW3kt6dclfaer2RFJ727fcb9W0vPu/qMJhxqsTt/M7JVm
Zu3X16g615+ddKzDcPf/lPSkmb2mvWqHpO92Ncvy2En1+pfz8evwm+o93CJFOn7TXuWyipkdkLTk7kckvd/Mdkk6I+k5SXuajG0IvyTpc+2/iRlJf+fuXzKz35Mkd79D0n2S3iLphKTTkt7TUKyh6vTtHZL2mdkZST+VtNvzeiz69yUdbv+z/QeS3lPIsVuxXv+yPn5mtlHSr0n63Y510Y8fj/4DQCGmcsgFAEpEQgeAQpDQAaAQJHQAKAQJHQAKQUIHgEKQ0AGgEP8P3x3UprO23Z8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets\n",
    "\n",
    "# Load iris, keep only classes 0 and 1 and only the first two feature\n",
    "# columns, then draw one scatter series per class.\n",
    "iris = datasets.load_iris()\n",
    "keep = iris.target < 2\n",
    "X = iris.data[keep, :2]\n",
    "y = iris.target[keep]\n",
    "\n",
    "for cls, colour in ((0, 'red'), (1, 'blue')):\n",
    "    pts = X[y == cls]\n",
    "    plt.scatter(pts[:, 0], pts[:, 1], color=colour)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "alpha 0.02172839506172842\n",
      "2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'feature': 0,\n",
       " 'value': 5.4,\n",
       " 'gini': 0.5,\n",
       " 'y': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
       " 'label': 0,\n",
       " 'left': {'feature': 1,\n",
       "  'value': 2.7,\n",
       "  'gini': 0.20761245674740492,\n",
       "  'y': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 1, 1, 1, 1, 1, 1]),\n",
       "  'label': 0,\n",
       "  'left': {'feature': 0,\n",
       "   'value': 4.5,\n",
       "   'gini': 0.2777777777777777,\n",
       "   'y': array([0, 1, 1, 1, 1, 1]),\n",
       "   'label': 1,\n",
       "   'left': {'label': 0, 'gini': 0, 'Tt': 1, 'y': array([0])},\n",
       "   'right': {'label': 1, 'gini': 0, 'Tt': 1, 'y': array([1, 1, 1, 1, 1])},\n",
       "   'Tt': 2,\n",
       "   'alpha': 0.2777777777777777},\n",
       "  'right': {'gini': 0,\n",
       "   'y': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1]),\n",
       "   'label': 0,\n",
       "   'Tt': 1,\n",
       "   'alpha': 0.02172839506172842},\n",
       "  'Tt': 5,\n",
       "  'alpha': 0.02412533640907346},\n",
       " 'right': {'feature': 1,\n",
       "  'value': 3.4,\n",
       "  'gini': 0.18325697625989168,\n",
       "  'y': array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1]),\n",
       "  'label': 1,\n",
       "  'left': {'label': 1,\n",
       "   'gini': 0,\n",
       "   'Tt': 1,\n",
       "   'y': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])},\n",
       "  'right': {'label': 0, 'gini': 0, 'Tt': 1, 'y': array([0, 0, 0, 0, 0])},\n",
       "  'Tt': 2,\n",
       "  'alpha': 0.18325697625989168},\n",
       " 'Tt': 7}"
      ]
     },
     "execution_count": 136,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def predictOne(tree, xTest):\n",
    "    \"\"\"Predict the label of one sample by walking ``tree`` from the root to a leaf.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tree : dict, a node as built by CART(): leaves have 'Tt' == 1 and a 'label';\n",
    "        internal nodes have 'feature', 'value', 'left' and 'right'.\n",
    "    xTest : indexable feature vector of the sample.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The 'label' stored on the leaf that is reached.\n",
    "    \"\"\"\n",
    "    node = tree\n",
    "    while node['Tt'] != 1:\n",
    "        # CART() sends rows with feature <= value to the left child, so a sample\n",
    "        # that EQUALS the threshold must go left here as well.\n",
    "        if xTest[node['feature']] <= node['value']:\n",
    "            node = node['left']\n",
    "        else:\n",
    "            node = node['right']\n",
    "    return node['label']\n",
    "    \n",
    "def predict(tree, X, y):\n",
    "    \"\"\"Misclassification rate of ``tree`` on (X, y).\n",
    "\n",
    "    Predicts every row of ``X`` with predictOne() and returns the fraction of\n",
    "    rows whose prediction differs from the matching entry of ``y``.\n",
    "    \"\"\"\n",
    "    yPredict = np.array([predictOne(tree, row) for row in X])\n",
    "    wrong = np.sum(yPredict != y)\n",
    "    return wrong/y.shape[0]\n",
    "\n",
    "# Grow the full tree, prune it, then keep whichever tree in the list\n",
    "# misclassifies the fewest rows of (X, y); the earliest tree wins ties.\n",
    "tree = CART(X, y)\n",
    "TreeList = prune(tree)\n",
    "print (len(TreeList))\n",
    "errors = []\n",
    "for tree in TreeList:\n",
    "    errors.append(predict(tree, X, y))\n",
    "bestError = min(errors)\n",
    "bestTree = TreeList[errors.index(bestError)]\n",
    "\n",
    "bestTree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
